{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d67a4729-cd2f-47e7-a4f6-f84a5677414f",
   "metadata": {},
   "source": [
    "# **05:** RAG (Retrieval Augmented Generation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b880d1ed-3db0-45a1-807e-1b47e9ce1320",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "# ! pip install faiss-cpu \"mistralai>=0.1.2\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b350586c-f8ca-4013-8840-46460e4ba295",
   "metadata": {},
   "source": [
    "### Load API key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b100be-c2cf-4e07-ba17-07eae31aea35",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from helper import load_mistral_api_key\n",
    "api_key, dlai_endpoint = load_mistral_api_key(ret_key=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe8609d5-9f27-4202-b0be-36db34412998",
   "metadata": {},
   "source": [
    "### Get data\n",
    "\n",
    "- You can go to https://www.deeplearning.ai/the-batch/\n",
    "- Search for any article and copy its URL."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "983ce5f6-5eb1-4442-8e04-822bdbd701f4",
   "metadata": {},
   "source": [
    "### Parse the article with BeautifulSoup "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c01740-72b4-482c-b61e-e272a734f01f",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "import requests\n",
    "from bs4 import BeautifulSoup\n",
    "import re\n",
    "\n",
    "response = requests.get(\n",
    "    \"https://www.deeplearning.ai/the-batch/a-roadmap-explores-how-ai-can-detect-and-mitigate-greenhouse-gases/\"\n",
    ")\n",
    "html_doc = response.text\n",
    "soup = BeautifulSoup(html_doc, \"html.parser\")\n",
    "tag = soup.find(\"div\", re.compile(\"^prose--styled\"))\n",
    "text = tag.text\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "491cf036-7bc9-4e96-94ab-737872de531a",
   "metadata": {},
   "source": [
    "### Optionally, save the text into a text file\n",
    "- You can upload the text file into a chat interface in the next lesson.\n",
    "- To download this file to your own machine, click on the \"Jupyter\" logo to view the file directory.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fbfa8e2-08af-445b-8134-7395cc693b5b",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "file_name = \"AI_greenhouse_gas.txt\"\n",
    "with open(file_name, 'w') as file:\n",
    "    file.write(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aad1aa61-9e1c-46c8-ae5e-61855df440f9",
   "metadata": {},
   "source": [
    "### Chunking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8494655e-bd87-49de-8f1d-69ffbc1c256e",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "chunk_size = 512\n",
    "chunks = [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c78c9936-0c1d-471c-b030-6c45639e7238",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "len(chunks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e42e3f06-09d6-4186-be0b-6034b0c8330e",
   "metadata": {},
   "source": [
    "### Get embeddings of the chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e77d9805-7a53-4210-9f80-f4de52285588",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from mistralai.client import MistralClient\n",
    "\n",
    "\n",
    "def get_text_embedding(txt):\n",
    "    client = MistralClient(api_key=api_key, endpoint=dlai_endpoint)\n",
    "    embeddings_batch_response = client.embeddings(model=\"mistral-embed\", input=txt)\n",
    "    return embeddings_batch_response.data[0].embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46503830-6ad5-493e-a629-152721e2d88e",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55396758-c3f3-45b3-b6e7-d4912c0899f2",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "text_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca875993-fe6d-42df-811e-a43891cd0350",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "len(text_embeddings[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cba33c7-9d1d-44d8-a01e-e30f16be1aac",
   "metadata": {},
   "source": [
    "### Store in a vector databsae\n",
    "- In this classroom, you'll use [Faiss](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0981a956-5f9c-4ea6-8fb3-a2cc9fe896d2",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "import faiss\n",
    "\n",
    "d = text_embeddings.shape[1]\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(text_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ee023ab-b26c-4df5-8a7b-7dd660bfad86",
   "metadata": {},
   "source": [
    "### Embed the user query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894d9764-9da9-4629-8f2a-c9dcaf6ceb8d",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "question = \"What are the ways that AI can reduce emissions in Agriculture?\"\n",
    "question_embeddings = np.array([get_text_embedding(question)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c4948cc-6d8b-449f-bc00-abb3591c7222",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "question_embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15989e10-d0ec-41be-b6be-fa317565a926",
   "metadata": {},
   "source": [
    "### Search for chunks that are similar to the query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c930b378-7aac-434c-881b-ab69d3edb93d",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "D, I = index.search(question_embeddings, k=2)\n",
    "print(I)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73aab584-1dbf-4532-b41e-0403eeeeb567",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "retrieved_chunk = [chunks[i] for i in I.tolist()[0]]\n",
    "print(retrieved_chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da042a53-4564-4057-9a60-9b57dffff6a1",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "prompt = f\"\"\"\n",
    "Context information is below.\n",
    "---------------------\n",
    "{retrieved_chunk}\n",
    "---------------------\n",
    "Given the context information and not prior knowledge, answer the query.\n",
    "Query: {question}\n",
    "Answer:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e7661e-51e2-4148-a44c-f262e7e85d56",
   "metadata": {
    "height": 268
   },
   "outputs": [],
   "source": [
    "from mistralai.models.chat_completion import ChatMessage\n",
    "\n",
    "\n",
    "def mistral(user_message, model=\"mistral-small-latest\", is_json=False):\n",
    "    client = MistralClient(api_key=api_key, endpoint=dlai_endpoint)\n",
    "    messages = [ChatMessage(role=\"user\", content=user_message)]\n",
    "\n",
    "    if is_json:\n",
    "        chat_response = client.chat(\n",
    "            model=model, messages=messages, response_format={\"type\": \"json_object\"}\n",
    "        )\n",
    "    else:\n",
    "        chat_response = client.chat(model=model, messages=messages)\n",
    "\n",
    "    return chat_response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a964d3-0dea-422a-83e6-342a4ab6906b",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "response = mistral(prompt)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a653b9c2-d6e7-42f5-88e9-d5dcd376e61e",
   "metadata": {},
   "source": [
    "## RAG + Function calling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41aac3a-20b4-4e33-ac58-f245577141f8",
   "metadata": {
    "height": 455
   },
   "outputs": [],
   "source": [
    "def qa_with_context(text, question, chunk_size=512):\n",
    "    # split document into chunks\n",
    "    chunks = [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]\n",
    "    # load into a vector database\n",
    "    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])\n",
    "    d = text_embeddings.shape[1]\n",
    "    index = faiss.IndexFlatL2(d)\n",
    "    index.add(text_embeddings)\n",
    "    # create embeddings for a question\n",
    "    question_embeddings = np.array([get_text_embedding(question)])\n",
    "    # retrieve similar chunks from the vector database\n",
    "    D, I = index.search(question_embeddings, k=2)\n",
    "    retrieved_chunk = [chunks[i] for i in I.tolist()[0]]\n",
    "    # generate response based on the retrieve relevant text chunks\n",
    "\n",
    "    prompt = f\"\"\"\n",
    "    Context information is below.\n",
    "    ---------------------\n",
    "    {retrieved_chunk}\n",
    "    ---------------------\n",
    "    Given the context information and not prior knowledge, answer the query.\n",
    "    Query: {question}\n",
    "    Answer:\n",
    "    \"\"\"\n",
    "    response = mistral(prompt)\n",
    "    return response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb4467f-0db8-4247-8150-8746a4630432",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "I.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b1bcc8d-b957-4167-b1e9-1353a6301762",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "I.tolist()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23d8ef9-36d4-4912-8303-d2fe3860d7c6",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "names_to_functions = {\"qa_with_context\": functools.partial(qa_with_context, text=text)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae3717b-37e6-40b3-93b1-cfd023b59079",
   "metadata": {
    "height": 336
   },
   "outputs": [],
   "source": [
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"qa_with_context\",\n",
    "            \"description\": \"Answer user question by retrieving relevant context\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"question\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"user question\",\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"question\"],\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2e442fa-5cca-4eb1-9c3f-24276fe4f75c",
   "metadata": {
    "height": 251
   },
   "outputs": [],
   "source": [
    "question = \"\"\"\n",
    "What are the ways AI can mitigate climate change in transportation?\n",
    "\"\"\"\n",
    "\n",
    "client = MistralClient(api_key=api_key, endpoint=dlai_endpoint)\n",
    "\n",
    "response = client.chat(\n",
    "    model=\"mistral-large-latest\",\n",
    "    messages=[ChatMessage(role=\"user\", content=question)],\n",
    "    tools=tools,\n",
    "    tool_choice=\"any\",\n",
    ")\n",
    "\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d349dd7-0138-4857-9bcb-69ed151cb1b8",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "tool_function = response.choices[0].message.tool_calls[0].function\n",
    "tool_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca751c08-e6e7-46a4-8e4c-a30407853cfc",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tool_function.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08910b72-2aaa-4393-a35a-5ed2671b8faf",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "args = json.loads(tool_function.arguments)\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "409f6a67-2787-424e-8b8d-92fc9b66bdf9",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "function_result = names_to_functions[tool_function.name](**args)\n",
    "function_result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f2d1982-a899-4ad5-a5de-2a33d46cd311",
   "metadata": {},
   "source": [
    "## More about RAG\n",
    "To learn about more advanced chunking and retrieval methods, you can check out:\n",
    "- [Advanced Retrieval for AI with Chroma](https://learn.deeplearning.ai/courses/advanced-retrieval-for-ai/lesson/1/introduction)\n",
    "  - Sentence window retrieval\n",
    "  - Auto-merge retrieval\n",
    "- [Building and Evaluating Advanced RAG Applications](https://learn.deeplearning.ai/courses/building-evaluating-advanced-rag)\n",
    "  - Query Expansion\n",
    "  - Cross-encoder reranking\n",
    "  - Training and utilizing Embedding Adapters\n"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:light"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
